from collections import defaultdict
import pybedtools
from numpy import *
from numpy.random import permutation
from scipy.stats import pearsonr, spearmanr, ttest_1samp, mannwhitneyu
from scipy.stats import f_oneway

from matplotlib.colors import LinearSegmentedColormap
from pylab import *


# Bridge to R through rpy2. numpy2ri.activate() switches on rpy2's numpy
# conversion, so numpy arrays can be handed to the R functions used below.
from rpy2 import robjects
import rpy2.robjects.numpy2ri
robjects.numpy2ri.activate()
from rpy2.robjects.packages import importr
# R packages used by this script: DESeq2 (size factors) and glmGamPoi
# (overdispersion estimates).
deseq = importr('DESeq2')
print("Using DESeq2 version %s" % deseq.__version__)
glmGamPoi = importr('glmGamPoi')


# From this point on every Python warning is raised as an exception, so the
# code below must avoid any call that would emit a warning.
import warnings
warnings.filterwarnings('error')


# Feature types in peaks.gff that read_ppvalues() silently ignores.
categories_skip = {
    "antisense",
    "prompt",
    "antisense_distal",
    "antisense_distal_upstream",
    "roadmap_dyadic",
    "roadmap_enhancer",
    "FANTOM5_enhancer",
    "novel_enhancer_CAGE",
    "novel_enhancer_HiSeq",
    "other",
}
# Feature types in peaks.gff that read_ppvalues() retains. A feature type that
# is in neither set makes read_ppvalues() raise.
categories_keep = {
    "sense_proximal",
    "sense_upstream",
    "sense_distal",
    "sense_distal_upstream",
}


# Time points of the time course, matching the tNN / NN_hr parts of the column
# names in the expression table header.
timepoints = (0, 1, 4, 12, 24, 96)

def read_ppvalues():
    """Load the ppvalue of every retained peak from peaks.gff.

    The feature type of each record (its third GFF field) is used as the peak
    category. Records whose category is listed in categories_skip are ignored.

    Returns:
        dict mapping "<chrom>_<start>-<end>_<strand>" to the float value of
        the record's 'ppvalue' attribute, in file order.

    Raises:
        Exception: if a category is in neither categories_skip nor
            categories_keep.
    """
    filename = "peaks.gff"
    print("Reading", filename)
    ppvalues = {}
    for record in pybedtools.BedTool(filename):
        category = record.fields[2]
        if category not in categories_skip:
            if category not in categories_keep:
                raise Exception("Unknown category %s" % category)
            coordinates = (record.chrom, record.start, record.end, record.strand)
            ppvalues["%s_%d-%d_%s" % coordinates] = float(record.attrs['ppvalue'])
    print("Read %d transcription initiation peaks" % len(ppvalues))
    return ppvalues

def calculate_dispersion(indices, counts):
    """Return the mean global overdispersion over groups of replicate libraries.

    Args:
        indices: sequence of integer index arrays, one per replicate group
            (as built by read_expression_data); each array holds the column
            indices of the libraries belonging to that group.
        counts: 2D array with one row per peak and one column per library.

    Returns:
        The mean, over all groups, of the single overdispersion value that
        glmGamPoi's glm_gp reports for the group with overdispersion='global'.
    """
    values = []
    for columns in indices:
        # The index arrays address libraries, i.e. columns of `counts`.
        # Indexing with counts[columns] would instead pick the first few
        # *peaks* across all libraries, so the replicates must be selected
        # along the second axis.
        result = glmGamPoi.glm_gp(counts[:, columns], overdispersion='global')
        assert len(result.rx['overdispersions']) == 1
        overdispersions = result.rx['overdispersions'][0]
        # With a global fit every row carries the same overdispersion value;
        # collapse to that single value and verify that it is unique.
        overdispersion = set(overdispersions)
        assert len(overdispersion) == 1
        values.append(overdispersion.pop())
    return mean(values)

def read_expression_data(ppvalues):
    """Read the expression values of the peaks in ppvalues from peaks.expression.txt.

    The header line must consist of exactly the expected column names (the
    'peak' column followed by 17 HiSeq and 16 CAGE libraries). Data rows are
    matched against the keys of ppvalues with a single forward scan, so the
    peaks must occur in the file in the same order as in ppvalues; rows for
    other peaks are skipped.

    Args:
        ppvalues: dict keyed by peak name, as returned by read_ppvalues().

    Returns:
        (cage_indices, hiseq_indices, data) where cage_indices and
        hiseq_indices are lists with one integer array per entry of
        `timepoints`, holding the 0-based data column indices of the libraries
        of that time point, and data is a 2D float array with one row per peak
        (in ppvalues order) and one column per library.

    Raises:
        AssertionError: if the header or a data row is malformed.
        Exception: if a column belongs to an unknown experiment.
        ValueError: if the file ends before all peaks were found.
    """
    expected_header = (
        'peak',
        'HiSeq_t00_r1', 'HiSeq_t00_r2', 'HiSeq_t00_r3',
        'HiSeq_t01_r1', 'HiSeq_t01_r2',
        'HiSeq_t04_r1', 'HiSeq_t04_r2', 'HiSeq_t04_r3',
        'HiSeq_t12_r1', 'HiSeq_t12_r2', 'HiSeq_t12_r3',
        'HiSeq_t24_r1', 'HiSeq_t24_r2', 'HiSeq_t24_r3',
        'HiSeq_t96_r1', 'HiSeq_t96_r2', 'HiSeq_t96_r3',
        'CAGE_00_hr_A', 'CAGE_00_hr_C', 'CAGE_00_hr_G', 'CAGE_00_hr_H',
        'CAGE_01_hr_A', 'CAGE_01_hr_C', 'CAGE_01_hr_G',
        'CAGE_04_hr_C', 'CAGE_04_hr_E',
        'CAGE_12_hr_A', 'CAGE_12_hr_C',
        'CAGE_24_hr_C', 'CAGE_24_hr_E',
        'CAGE_96_hr_A', 'CAGE_96_hr_C', 'CAGE_96_hr_E',
    )
    n_fields = len(expected_header)
    peaks = list(ppvalues.keys())
    filename = "peaks.expression.txt"
    print("Reading", filename)
    hiseq_indices = [list() for timepoint in timepoints]
    cage_indices = [list() for timepoint in timepoints]
    data = []
    # The context manager closes the file even when one of the checks below
    # raises; previously the handle leaked on every error path.
    with open(filename) as stream:
        words = next(stream).split()
        assert len(words) == n_fields, \
            "Expected %d header columns, found %d" % (n_fields, len(words))
        for position, (found, wanted) in enumerate(zip(words, expected_header)):
            assert found == wanted, \
                "Header column %d is %r, expected %r" % (position, found, wanted)
        # Group the data columns by experiment and time point. `index` counts
        # from the first library column, i.e. it addresses columns of `data`.
        for index, word in enumerate(words[1:]):
            condition, replicate = word.rsplit("_", 1)
            experiment, timepoint = condition.split("_", 1)
            if experiment == "HiSeq":
                assert replicate in ("r1", "r2", "r3")
                assert timepoint.startswith("t")
                timepoint = int(timepoint[1:])
                hiseq_indices[timepoints.index(timepoint)].append(index)
            elif experiment == "CAGE":
                assert replicate in ("A", "C", "E", "G", "H")
                assert timepoint.endswith("_hr")
                timepoint = int(timepoint[:-3])
                cage_indices[timepoints.index(timepoint)].append(index)
            else:
                raise Exception("Unknown experiment %s" % experiment)
        hiseq_indices = [array(group) for group in hiseq_indices]
        cage_indices = [array(group) for group in cage_indices]
        n_assigned = sum([len(hiseq_group) + len(cage_group)
                          for hiseq_group, cage_group in zip(hiseq_indices, cage_indices)])
        assert n_assigned == n_fields - 1
        # Single forward scan: peaks[i] is the next peak we are waiting for.
        i = 0
        n = len(peaks)
        for line in stream:
            words = line.split()
            assert len(words) == n_fields
            peak = words[0]
            row = array(words[1:], float)
            if peak == peaks[i]:
                data.append(row)
                i += 1
                if i == n:
                    break
        else:
            # The file was exhausted before the last peak was matched.
            raise ValueError("Failed to find all peaks")
    data = array(data)
    return cage_indices, hiseq_indices, data

def make_figure_scatter(cage_indices, hiseq_indices, counts, factors, ppvalues):
    """Plot average CAGE expression against average HiSeq expression per peak.

    The counts are divided by `factors` and rescaled so that the mean column
    total equals one million. For each dataset the replicates of every group
    in the index list are averaged, and the group averages are averaged again
    to give one value per peak. The points are colored by the values of
    `ppvalues`, the Spearman and log-Pearson correlations are printed, and the
    figure is saved as figure_sense_timecourse_scatter.svg and .png.

    Args:
        cage_indices: list of column index arrays, one per group, for CAGE.
        hiseq_indices: list of column index arrays, one per group, for HiSeq;
            must have the same length as cage_indices.
        counts: 2D array with one row per peak and one column per library.
        factors: per-library normalization factors.
        ppvalues: dict whose values (in order) supply one color value per peak.
    """
    colors = array(list(ppvalues.values()))
    tpm = counts / factors
    # The mean library total should be 1 million; rescale so that it is.
    tpm *= 1.e6 / mean(sum(tpm, 0))
    n_groups = len(cage_indices)
    assert n_groups == len(hiseq_indices)
    n = len(counts)

    def average_over_groups(groups):
        # Mean over the replicates of each group, then mean over the groups.
        group_means = zeros((n, n_groups))
        for column, members in enumerate(groups):
            group_means[:, column] = mean(tpm[:, members], 1)
        return mean(group_means, 1)

    cage_tpm = average_over_groups(cage_indices)
    hiseq_tpm = average_over_groups(hiseq_indices)
    assert len(cage_tpm) == n
    assert len(hiseq_tpm) == n
    print("Plotting a scatter plot with %d points" % n)
    fig = figure(figsize=(5,5))
    axes([0.20, 0.20, 0.55, 0.55])
    cmap = LinearSegmentedColormap.from_list("mycmap", ['blue', 'lightgray', 'red'])
    scatter(cage_tpm, hiseq_tpm, c=colors, s=3, vmin=-12, vmax=+12, cmap=cmap)
    xscale('log')
    yscale('log')
    fig.axes[0].set_aspect('equal', adjustable='box')
    # Give both axes the same range, wide enough for either one.
    x_range = xlim()
    y_range = ylim()
    lower = min(x_range[0], y_range[0])
    upper = max(x_range[1], y_range[1])
    xlim(lower, upper)
    ylim(lower, upper)
    xticks(fontsize=8)
    yticks(fontsize=8)
    xlabel("Long capped RNA (CAGE libraries),\naverage expression [tpm]", fontsize=8)
    ylabel("Short capped RNA (single-end libraries),\naverage expression [tpm]", fontsize=8)
    # Separate narrow axes to the right of the plot hold the color bar.
    cb = colorbar(cax=axes([0.80, 0.20, 0.03, 0.55]))
    cb.set_label("$-\\log(p) \\times \\mathrm{sign}$", fontsize=8)
    cb.ax.tick_params(labelsize=8)
    r, p = spearmanr(cage_tpm, hiseq_tpm)
    print("CAGE average expression vs HiSeq average expression: Spearman correlation across genes is %.2f (p = %g)" % (r, p))
    r, p = pearsonr(log(cage_tpm+1), log(hiseq_tpm+1))
    print("CAGE average expression vs HiSeq average expression: Pearson correlation across genes is %.2f (p = %g)" % (r, p))
    for filename in ("figure_sense_timecourse_scatter.svg",
                     "figure_sense_timecourse_scatter.png"):
        print("Saving figure as %s" % filename)
        savefig(filename)

def estimate_normalization_factors(counts, cage_indices, hiseq_indices):
    """Estimate one size factor per library with DESeq2.

    Every column of `counts` is annotated with its dataset ("CAGE" or "HiSeq")
    and its time point (as a string), a DESeqDataSet is built with the design
    "~ timepoint + dataset", and R's estimateSizeFactors is run on it.

    Args:
        counts: 2D array with one row per peak and one column per library.
        cage_indices: per entry of `timepoints`, the column indices of the
            CAGE libraries of that time point.
        hiseq_indices: same as cage_indices, for the HiSeq libraries.

    Returns:
        The value returned by R's sizeFactors for the data set.
    """
    _, n_libraries = shape(counts)
    dataset_labels = [None] * n_libraries
    timepoint_labels = [None] * n_libraries
    for position, timepoint in enumerate(timepoints):
        label = str(timepoint)
        groups = (("CAGE", cage_indices[position]),
                  ("HiSeq", hiseq_indices[position]))
        for dataset, members in groups:
            for column in members:
                dataset_labels[column] = dataset
                timepoint_labels[column] = label
    column_data = robjects.DataFrame({
        'dataset': robjects.StrVector(dataset_labels),
        'timepoint': robjects.StrVector(timepoint_labels),
    })
    dds = deseq.DESeqDataSetFromMatrix(countData=counts,
                                       colData=column_data,
                                       design=robjects.Formula("~ timepoint + dataset"))
    dds = robjects.r['estimateSizeFactors'](dds)
    return robjects.r['sizeFactors'](dds)

# Load the ppvalue of every retained peak from peaks.gff.
ppvalues = read_ppvalues()

# Load the count table rows of those peaks, together with the column indices
# of the CAGE and HiSeq libraries grouped by time point.
cage_indices, hiseq_indices, counts = read_expression_data(ppvalues)

# One normalization factor per library (column of counts).
factors = estimate_normalization_factors(counts, cage_indices, hiseq_indices)

# Called with the factors as estimated, before they are rescaled in place
# further down.
make_figure_scatter(cage_indices, hiseq_indices, counts, factors, ppvalues)

# One dispersion estimate per dataset, each averaged over its groups of
# replicates.
hiseq_dispersion = calculate_dispersion(hiseq_indices, counts)
cage_dispersion = calculate_dispersion(cage_indices, counts)

# A single dispersion, the average of the two datasets, is used from here on.
dispersion = mean([cage_dispersion, hiseq_dispersion])
print("dispersion = %f" % dispersion)

# Rescale the normalization factors IN PLACE so that their median becomes 1.
factor = median(factors)
factors /= factor

# Normalize the counts, then apply an arcsinh-based transformation. Because
# arcsinh(z) approaches log(2z) for large z, the expression below approaches
# log2(data) when dispersion*data is large, while a count of zero still maps
# to a finite value.
# NOTE(review): this looks like a variance-stabilizing transformation for
# data with the estimated overdispersion - confirm against the methods.
data = counts / factors
data = (2 * arcsinh(sqrt(dispersion*data)) - log(dispersion) - log(4))/log(2)

def _largest_within_group_range(groups):
    """Return the largest (max - min) spread found inside any single group."""
    return max([max(group) - min(group) for group in groups])


# Keep the peaks for which a one-way ANOVA across time points is significant
# (p < 0.05) in the CAGE libraries or in the HiSeq libraries. A peak is dropped
# without testing as soon as either dataset shows no variation inside any of
# its time-point groups.
# NOTE(review): presumably this guard avoids degenerate-input warnings from
# f_oneway, which the file-level warnings filter would raise as errors.
indices = []
for i, profile in enumerate(data):
    values = [profile[ii] for ii in cage_indices]
    if _largest_within_group_range(values) < 1.e-12:
        continue
    pvalue_cage = f_oneway(*values).pvalue
    values = [profile[ii] for ii in hiseq_indices]
    if _largest_within_group_range(values) < 1.e-12:
        continue
    pvalue_hiseq = f_oneway(*values).pvalue
    if pvalue_cage < 0.05 or pvalue_hiseq < 0.05:
        indices.append(i)

indices = array(indices, int)
print("Number of differentially expressed peaks: %d" % len(indices))
data = data[indices, :]

# Average the replicates of every time point, separately for each dataset.
# The number of columns follows `timepoints` instead of a hard-coded 6, and
# the loop variable is named `columns` so that it no longer overwrites
# `indices`, the array of selected peak rows.
n_timepoints = len(timepoints)
hiseq_mean_data = zeros((len(data), n_timepoints))
cage_mean_data = zeros((len(data), n_timepoints))

for i, columns in enumerate(cage_indices):
    cage_mean_data[:, i] = mean(data[:, columns], 1)

for i, columns in enumerate(hiseq_indices):
    hiseq_mean_data[:, i] = mean(data[:, columns], 1)

# Center every peak (row) on its own mean over the time points, so that the
# values become deviations from the peak's average level.
cage_mean_data -= mean(cage_mean_data, 1, keepdims=True)
hiseq_mean_data -= mean(hiseq_mean_data, 1, keepdims=True)

# One value per (peak, time point) pair, in the same order for both datasets.
x = cage_mean_data.flatten()
y = hiseq_mean_data.flatten()

# Density plot (log-scaled hexagonal bins) of the centered CAGE values against
# the centered HiSeq values, on a square [-8, 8] x [-8, 8] window.
fig = figure()

limit = 8
hexbin(x, y, bins='log', extent=(-limit, limit, -limit, limit))
# Optional y = x guide line, currently disabled:
# plot([-8,8],[-8,8],color='red',linestyle='--')
fig.axes[0].set_aspect('equal', adjustable='box')

xlim(-limit, limit)
ylim(-limit, limit)

xticks(fontsize=8)
yticks(fontsize=8)
xlabel("Long capped RNA (CAGE libraries),\nlog$_2$ expression fold change", fontsize=8)
ylabel("Short capped RNA (single-end libraries),\nlog$_2$ expression fold change", fontsize=8)

# Report how well the two datasets agree over all plotted points.
r, p = spearmanr(x, y)
print("Spearman correlation = %.2f (p = %g)" % (r, p))
r, p = pearsonr(x, y)
print("Pearson correlation = %.2f (p = %g)" % (r, p))

# Write the figure in both vector and bitmap form.
for filename in ("figure_sense_timecourse_logratios.svg",
                 "figure_sense_timecourse_logratios.png"):
    print("Saving figure as %s" % filename)
    savefig(filename)
